{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# JASCO\n",
    "Welcome to JASCO's demo Jupyter notebook.\n",
    "Here you will find a self-contained example of how to use JASCO for temporally controlled music generation.\n",
    "\n",
    "You can choose a model from the following selection:\n",
    "1. facebook/jasco-chords-drums-400M - 10s music generation conditioned on text, chords and drums, 400M parameters\n",
    "2. facebook/jasco-chords-drums-1B - 10s music generation conditioned on text, chords and drums, 1B parameters\n",
    "3. facebook/jasco-chords-drums-melody-400M - 10s music generation conditioned on text, chords, drums and melody, 400M parameters\n",
    "4. facebook/jasco-chords-drums-melody-1B - 10s music generation conditioned on text, chords, drums and melody, 1B parameters\n",
    "\n",
    "First, we start by initializing the JASCO model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from audiocraft.models import JASCO\n",
    "\n",
    "model = JASCO.get_pretrained('facebook/jasco-chords-drums-melody-400M', chords_mapping_path='../assets/chord_to_index_mapping.pkl')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let us configure the generation parameters. Specifically, you can control the following:\n",
    "* `cfg_coef_all` (float, optional): Classifier-free guidance coefficient for the fully conditional term. Defaults to 5.0.\n",
    "* `cfg_coef_txt` (float, optional): Classifier-free guidance coefficient for the additional text-conditional term. Defaults to 0.0.\n",
    "\n",
    "Parameters you do not set fall back to their defaults. The next cell overrides both, setting `cfg_coef_all=0.0` and `cfg_coef_txt=5.0`, so that only the text term is weighted for the text-only example that follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.set_generation_params(\n",
    "    cfg_coef_all=0.0,\n",
    "    cfg_coef_txt=5.0\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we can go ahead and start generating music given textual prompts."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Text-conditional Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from audiocraft.utils.notebook import display_audio\n",
    "\n",
    "# set textual prompt\n",
    "text = \"Funky groove with electric piano playing blue chords rhythmically\"\n",
    "\n",
    "# run the model\n",
    "print(\"Generating...\")\n",
    "output = model.generate(descriptions=[text], progress=True)\n",
    "\n",
    "# display the result\n",
    "print(f\"Text: {text}\\n\")\n",
    "display_audio(output, sample_rate=model.compression_model.sample_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can start adding temporal controls! We begin with conditioning on chord progressions:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Chords-conditional Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.set_generation_params(\n",
    "    cfg_coef_all=1.5,\n",
    "    cfg_coef_txt=3.0\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from audiocraft.utils.notebook import display_audio\n",
    "\n",
    "# set textual prompt\n",
    "text = \"Strings, woodwind, orchestral, symphony.\"\n",
    "\n",
    "# define chord progression\n",
    "chords = [('C', 0.0), ('D', 2.0), ('F', 4.0), ('Ab', 6.0), ('Bb', 7.0), ('C', 8.0)]\n",
    "\n",
    "# run the model\n",
    "print(\"Generating...\")\n",
    "output = model.generate_music(descriptions=[text], chords=chords, progress=True)\n",
    "\n",
    "# display the result\n",
    "print(f'Text: {text}')\n",
    "print(f'Chord progression: {chords}')\n",
    "display_audio(output, sample_rate=model.compression_model.sample_rate)"
   ]
  },
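  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The chord progression above is a list of `(chord, start_time)` pairs, with start times in seconds. Below is a minimal sketch of a sanity check you could run before generating. `check_chord_progression` is a hypothetical helper, not part of audiocraft, and its rules (non-empty labels, strictly increasing start times inside the 10s generation window) are assumptions rather than JASCO's own validation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_chord_progression(chords, max_duration=10.0):\n",
    "    \"\"\"Hypothetical sanity check (not an audiocraft API).\n",
    "\n",
    "    Assumes `chords` is a list of (label, start_time_in_seconds) pairs.\n",
    "    \"\"\"\n",
    "    previous_start = None\n",
    "    for label, start in chords:\n",
    "        if not isinstance(label, str) or not label:\n",
    "            raise ValueError(f'chord label must be a non-empty string, got {label!r}')\n",
    "        if not 0.0 <= start < max_duration:\n",
    "            raise ValueError(f'{label}: start time {start} is outside [0, {max_duration})')\n",
    "        if previous_start is not None and start <= previous_start:\n",
    "            raise ValueError(f'{label}: start times must be strictly increasing')\n",
    "        previous_start = start\n",
    "    return True\n",
    "\n",
    "\n",
    "# the progression used in the cell above passes the check\n",
    "check_chord_progression([('C', 0.0), ('D', 2.0), ('F', 4.0), ('Ab', 6.0), ('Bb', 7.0), ('C', 8.0)])"
   ]
  },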
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we can condition the generation on drum tracks:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Drums-conditional Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchaudio\n",
    "from audiocraft.utils.notebook import display_audio\n",
    "\n",
    "\n",
    "# load drum prompt\n",
    "drums_waveform, sr = torchaudio.load(\"../assets/sep_drums_1.mp3\")\n",
    "\n",
    "# set textual prompt \n",
    "text = \"distortion guitars, heavy rock, catchy beat\"\n",
    "\n",
    "# run the model\n",
    "print(\"Generating...\")\n",
    "output = model.generate_music(\n",
    "    descriptions=[text],\n",
    "    drums_wav=drums_waveform,\n",
    "    drums_sample_rate=sr,\n",
    "    progress=True\n",
    ")\n",
    "\n",
    "# display the result\n",
    "print('drum prompt:')\n",
    "display_audio(drums_waveform, sample_rate=sr)\n",
    "print(f'Text: {text}')\n",
    "display_audio(output, sample_rate=model.compression_model.sample_rate)"
   ]
  },
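  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`torchaudio.load` returns the waveform as a `[channels, samples]` tensor together with its sample rate. As a small sketch, the hypothetical helper below (`describe_waveform` is not part of audiocraft) reports the channel count and duration, which is handy for comparing a drum prompt against the 10s generation window. It runs on a synthetic tensor with arbitrary example values; how JASCO treats prompts longer or shorter than that window is not covered in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def describe_waveform(wav, sample_rate):\n",
    "    \"\"\"Hypothetical helper (not an audiocraft API): summarize a [channels, samples] tensor.\"\"\"\n",
    "    channels, num_samples = wav.shape\n",
    "    return {'channels': channels, 'seconds': num_samples / sample_rate}\n",
    "\n",
    "\n",
    "# synthetic stand-in for a loaded drum prompt: 2 channels, 12 seconds at an assumed 32 kHz\n",
    "example_sr = 32000\n",
    "example_wav = torch.zeros(2, example_sr * 12)\n",
    "print(describe_waveform(example_wav, example_sr))"
   ]
  },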
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also combine multiple temporal controls! Let's move on to generating with both chords and drums conditioning:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Drums + Chords conditioning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchaudio\n",
    "from audiocraft.utils.notebook import display_audio\n",
    "\n",
    "\n",
    "# load drum prompt\n",
    "drums_waveform, sr = torchaudio.load(\"../assets/sep_drums_1.mp3\")\n",
    "\n",
    "# set textual prompt \n",
    "text = \"string quartet, orchestral, dramatic\"\n",
    "\n",
    "# define chord progression\n",
    "chords = [('C', 0.0), ('D', 2.0), ('F', 4.0), ('Ab', 6.0), ('Bb', 7.0), ('C', 8.0)]\n",
    "\n",
    "# run the model\n",
    "print(\"Generating...\")\n",
    "output = model.generate_music(\n",
    "    descriptions=[text],\n",
    "    drums_wav=drums_waveform,\n",
    "    drums_sample_rate=sr,\n",
    "    chords=chords,\n",
    "    progress=True\n",
    ")\n",
    "\n",
    "# display the result\n",
    "print('drum prompt:')\n",
    "display_audio(drums_waveform, sample_rate=sr)\n",
    "print(f'Chord progression: {chords}')\n",
    "print(f'Text: {text}')\n",
    "display_audio(output, sample_rate=model.compression_model.sample_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Melody + Drums + Chords conditioning - inference example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "from demucs import pretrained\n",
    "from demucs.apply import apply_model\n",
    "from demucs.audio import convert_audio\n",
    "import torch\n",
    "import torchaudio\n",
    "from audiocraft.utils.notebook import display_audio\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# --------------------------\n",
    "# First, choose file to load\n",
    "# --------------------------\n",
    "fnames = ['salience_1', 'salience_2']\n",
    "chords = [\n",
    "    [('N',  0.0), ('Eb7',  1.088000000), ('C#',  4.352000000), ('D',  4.864000000), ('Dm7',  6.720000000), ('G7',  8.256000000), ('Am7b5/G',  9.152000000)],  # for salience 1\n",
    "    [('N',  0.0), ('C',  0.320000000), ('Dm7',  3.456000000), ('Am',  4.608000000), ('F',  8.320000000), ('C',  9.216000000)]  # for salience 2\n",
    "]\n",
    "file_idx = 0  # either 0 or 1\n",
    "\n",
    "\n",
    "# ------------------------------------\n",
    "# display audio, melody map and chords\n",
    "# ------------------------------------\n",
    "def plot_chromagram(tensor):\n",
    "    # Check if tensor is a PyTorch tensor\n",
    "    if not torch.is_tensor(tensor):\n",
    "        raise ValueError('Input should be a PyTorch tensor')\n",
    "    tensor = tensor.numpy().T  # C, T\n",
    "    plt.figure(figsize=(20, 20))\n",
    "    plt.imshow(tensor, cmap='binary', interpolation='nearest', origin='lower')\n",
    "    plt.show()\n",
    "\n",
    "# load salience and display the corresponding wav\n",
    "melody_prompt_wav, melody_prompt_sr = torchaudio.load(f\"../assets/{fnames[file_idx]}.wav\")\n",
    "print(\"Source melody:\")\n",
    "display_audio(melody_prompt_wav, sample_rate=melody_prompt_sr)\n",
    "melody = torch.load(f\"../assets/{fnames[file_idx]}.th\", weights_only=True)\n",
    "plot_chromagram(melody)\n",
    "print(\"Chords:\")\n",
    "print(chords[file_idx])\n",
    "\n",
    "# ------------------------------------------------------\n",
    "# use demucs to separate the drums stem from the src mix\n",
    "# ------------------------------------------------------\n",
    "def _get_drums_stem(wav: torch.Tensor, sample_rate: int) -> torch.Tensor:\n",
    "    \"\"\"Separate the drums stem from the mix with Demucs and return it as mono audio at the input sample rate.\"\"\"\n",
    "    demucs_model = pretrained.get_model('htdemucs').to('cuda')\n",
    "    wav = convert_audio(\n",
    "        wav, sample_rate, demucs_model.samplerate, demucs_model.audio_channels)  # type: ignore\n",
    "    stems = apply_model(demucs_model, wav.cuda().unsqueeze(0), device='cuda').squeeze(0)\n",
    "    drum_stem = stems[demucs_model.sources.index('drums')]  # keep only the drums stem for drums conditioning\n",
    "    return convert_audio(drum_stem.cpu(), demucs_model.samplerate, sample_rate, 1)  # type: ignore\n",
    "drums_wav = _get_drums_stem(melody_prompt_wav, melody_prompt_sr)\n",
    "print(\"Separated drums:\")\n",
    "display_audio(drums_wav, sample_rate=melody_prompt_sr)\n",
    "\n",
    "# ----------------------------------\n",
    "# Generate using the loaded controls\n",
    "# ----------------------------------\n",
    "# free-form text prompts, chosen arbitrarily\n",
    "texts = [\n",
    "    '90s rock with heavy drums and hammond',\n",
    "    '80s pop with groovy synth bass and drum machine',\n",
    "    'folk song with leading accordion',\n",
    "]\n",
    "\n",
    "print(\"Generating...\")\n",
    "# replace the dynamic solver with a simple Euler solver\n",
    "model.set_generation_params(cfg_coef_all=1.5, cfg_coef_txt=2.5, euler=True, euler_steps=50)\n",
    "output = model.generate_music(\n",
    "    descriptions=texts,\n",
    "    chords=chords[file_idx],\n",
    "    drums_wav=drums_wav,\n",
    "    drums_sample_rate=melody_prompt_sr,\n",
    "    melody_salience_matrix=melody.permute(1, 0),\n",
    "    progress=True\n",
    ")\n",
    "display_audio(output, sample_rate=model.compression_model.sample_rate)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jasco_dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
